from typing import Dict, Iterator, List, Tuple, Optional
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from dataclasses import dataclass

class Text:
    """A piece of text together with its space-delimited word list.

    The raw string is kept in ``text`` and its words in ``words``. Instances
    are treated as immutable by the methods here: ``replace`` and
    ``generate_maskings`` build new ``Text`` objects instead of mutating.
    """

    def __init__(self, text: str):
        self.text = text
        self.words = Text._text_to_words(text)

    def __str__(self):
        return self.text

    def __repr__(self):
        return self.text

    def __len__(self):
        # Length is measured in words, not characters.
        return len(self.words)

    def __getitem__(self, i):
        return self.words[i]

    def replace(self, indx: int, repl: str) -> "Text":
        """Return a new ``Text`` with the word at ``indx`` replaced by ``repl``.

        ``repl`` is split on whitespace, so a multi-word replacement expands
        into several words and an empty replacement removes the word.

        Negative indices count from the end, consistent with ``__getitem__``.
        Indices outside ``[-len(self), len(self))`` are not validated and fall
        back to plain slice semantics (e.g. ``indx == len(self)`` appends).
        """
        size = len(self.words)
        # Normalize in-range negative indices first: with a raw negative
        # index, ``words[indx + 1:]`` would restart near the beginning of the
        # list (``words[0:]`` for -1) and duplicate words instead of
        # replacing one.
        if -size <= indx < 0:
            indx += size
        new_words = self.words[:indx] + repl.split() + self.words[indx + 1:]
        return Text(Text._words_to_text(new_words))

    def generate_maskings(self, mask_word: str = '[MASK]') -> Iterator[Tuple[int, "Text"]]:
        """Yield ``(index, masked_text)`` for every word position in order.

        Each ``masked_text`` is a copy of this text in which the word at
        ``index`` has been replaced by ``mask_word``.
        """
        for position in range(len(self.words)):
            yield position, self.replace(position, mask_word)

    @staticmethod
    def _text_to_words(text: str) -> List[str]:
        # Splitting on a single space (not arbitrary whitespace) keeps the
        # round trip through _words_to_text lossless; consecutive spaces
        # therefore show up as empty-string words.
        return text.split(' ')

    @staticmethod
    def _words_to_text(words: List[str]) -> str:
        return ' '.join(words)
    

class Score:
    """Scores candidate texts by how close their embedding is to a target.

    The embedding of a text is taken from one layer of the hidden states
    returned by ``model.get_representations``. The score is the negated
    Euclidean distance to ``target_embedding``, so higher means closer.
    """

    def __init__(
        self,
        model,
        layer: int,
        target_embedding: np.ndarray,
        original_text: str,
        device: str = "cuda:0"
    ):
        """Store the configuration and embed ``original_text`` once.

        Args:
            model: object exposing ``get_representations(text)``.
            layer: index selected along axis 1 of the model's hidden states.
            target_embedding: vector that candidate embeddings are compared to.
            original_text: reference text; its embedding is stored in
                ``original_embedding``.
            device: stored on the instance; not read by any method of this
                class.
        """
        self.model = model
        self.layer = layer
        self.device = device
        self.target_embedding = target_embedding
        self.original_text = original_text
        # Needs self.model and self.layer, so it is computed last.
        self.original_embedding = self._get_embedding(original_text)

    def _get_embedding(self, text: str) -> np.ndarray:
        """Return the selected layer's representation of ``text`` as NumPy."""
        # Inference only: no autograd graph is needed for the forward pass.
        with torch.no_grad():
            representations = self.model.get_representations(text)
        # Layer is picked on axis 1; assumes (batch, layers, hidden) layout
        # with a singleton batch that squeeze() drops -- TODO confirm.
        layer_output = representations[:, self.layer, :]
        return layer_output.squeeze().cpu().numpy()

    def compute_distance(self, text: str) -> float:
        """Return the Euclidean distance from ``text``'s embedding to the target."""
        candidate = self._get_embedding(str(text))
        offset = candidate - self.target_embedding
        return float(np.linalg.norm(offset))

    def score(self, text: str) -> float:
        """Return the negated distance to the target (higher is closer)."""
        return -self.compute_distance(str(text))

